class: center, middle, inverse, title-slide

# Poisson lognormal model
## Optimisation with machine-learning techniques
### B. Batardière, J. Chiquet, J. Kwon + PLN team
MIA Paris-Saclay, AgroParisTech, INRAE
Last update 23 March, 2022
---

# PLN Google Search
### Github group

---

# PLN Google Search
### Mind the Spelling!

---

# Poisson lognormal model

.pull-left2[
Let
<br/>
We observe
]
.pull-right2[
- `\(n\)` be the number of sites/cells/samples
- `\(p\)` the number of species/genes/variables
- `\(d\)` the number of environmental covariates
- `\(n\)` measures of the joint counts `\(Y_i\in\mathbb{N}^p\)`
- `\(n\)` measures of the environmental covariates `\(X_i\in\mathbb{R}^d\)`
]

`$$\begin{aligned} Z_{i} &= \beta^{\top} X_i + CW_i, \qquad W_{i} \sim \mathcal{N}\left(0, I_{q}\right) \\[1.25ex] Y_{i j} \mid Z_{i j} & \sim \mathcal{P}\left(\exp \left( Z_{i j}\right)\right) \\ \end{aligned}$$`

where `\(q\leq p\)` is the dimension of the latent space. The model parameters encompass

- the matrix of regression parameters `\(\beta = (\beta_{kj})_{1 \leq k \leq d, 1 \leq j \leq p}\)`,
- the matrix `\(C \in \mathbb R^{p\times q}\)` mapping the latent variable `\(W_i\)` from `\(\mathbb{R}^q\)` to `\(\mathbb{R}^p\)`.

If `\(p = q\)`, `\(\theta = (\beta, \Sigma = CC^{\top})\)`, .important[standard PLN], see [AH89; CMR21]

If `\(q < p\)`, `\(\theta = (\beta, C)\)`, .important[PLN-PCA], see [CMR18]

---

# .small[Example for visualization of single-cell data]

A dataset containing the counts of the 500 most varying transcripts in the mixtures of 5 cell lines in human liver, for a total of 3918 cells

<img src="data:image/png;base64,#optimPLN_files/figure-html/unnamed-chunk-2-1.png" width="50%" style="display: block; margin: auto;" />

`\(\rightsquigarrow\)` the `R` implementation with V-EM and standard nonlinear optimization techniques works up to hundreds of variables and thousands of samples. Can we do better?
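For reference, counts with this structure can be simulated directly from the generative model above; a minimal numpy sketch, with dimensions and scales that are purely illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, d, q = 100, 20, 2, 5               # samples, variables, covariates, latent dim

X = rng.normal(size=(n, d))              # covariates X_i
B = rng.normal(scale=0.3, size=(d, p))   # regression matrix beta (d x p)
C = rng.normal(scale=0.3, size=(p, q))   # loading matrix C (p x q)

W = rng.normal(size=(n, q))              # W_i ~ N(0, I_q)
Z = X @ B + W @ C.T                      # Z_i = beta^T X_i + C W_i
Y = rng.poisson(np.exp(Z))               # Y_ij | Z_ij ~ P(exp(Z_ij))
print(Y.shape)                           # (100, 20)
```

Row `i` of `Y` plays the role of one count vector `\(Y_i\)`.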
.important[ML + Pytorch] [Pas+17]

---

# Inference

We wish to solve `\(\hat{\theta} = \arg\max _{\theta} \sum _{i =1} ^n \log p_{\theta}(Y_i)\)`

### .content-box-red[.small[Approximated Expectation-Maximization (Variational EM), [BKM19; CMR18]]]

.content-box-yellow[
`$$\arg\max_{\theta, \, q\in\mathcal{Q}} \; \sum_{i=1}^n J_i(\theta, \psi_i) = \sum_{i=1}^n \log p_{\theta}(Y_i)-KL\left[q_i(W_i) \| p_{\theta}(W_i \mid Y_i)\right]$$`

where `\(J_i(\theta, \psi_i) = \mathbb{E}_{q_i ;\theta^{t}}\left[\log p_{\theta}(Y_i, W_i) \mid Y_i\right] - \mathbb{E}_{q_i ;\theta^{t}} [\log q_i(W_i;\psi_i)]\)`

and `\(q_i(.; \psi_i) \in \mathcal{Q} = \left\{ \mathcal{N}\left(M_{i}, \operatorname{diag} (S_{i}\odot S_i )\right), M_i \in \mathbb{R}^q, S_i \in \mathbb{R}^q\right\}\)`
]

--

### .content-box-red[.small[Direct optimization by approximating the gradient of the objective]]

.content-box-yellow[
`$$\nabla_{\theta} \sum _{i =1} ^n \log p_{\theta}(Y_i) = \sum _{i =1} ^n \nabla_{\theta} \log \left( \int_{\mathbb{R}^q} p_{\theta} (Y_i|W_i) p(W_i)\mathrm{d} W_i \right)$$`

- Ingredient 1: stochastic-gradient method with variance reduction
- Ingredient 2: Monte-Carlo estimation and importance sampling + Pytorch
]

---

# 1.
Variational EM

### PLN

`$$Y_{i} \mid Z_{i} \sim \mathcal{P}\left(\exp \left( Z_{i}\right)\right), \qquad Z_{i} \sim\mathcal{N}(\beta^{\top} X_i, \Sigma)$$`

- M-step is explicit: for fixed `\(\psi_i = (M_i, S_i)\)`,
`$$\hat{\Sigma} = \frac{1}{n} \sum_{i}\left((M^{(t)}-X\beta)_{i} (M^{(t)}-X\beta)_{i}^{\top}+\operatorname{diag}(S^{(t)}_{i}\odot S^{(t)}_{i})\right), \quad\hat{\beta} = (X^{\top}X)^{-1}X^{\top}M^{(t)}$$`
- E-step: for fixed `\(\theta = \hat{\theta}\)`, solve in `\(\psi_i\)` by gradient ascent

### PLN-PCA

Joint gradient ascent on `\((\theta, \psi) = (C, \beta, \{M_i, S_i\}_{i=1,\dots,n})\)`

### Tuning

- Use Rprop [RB93] for the gradient ascent
- SGD with momentum/adaptive learning rate
- Uses the sign of the gradient + one learning rate per parameter

---

# Performance of V-EM

<div class="figure" style="text-align: center">
<img src="data:image/png;base64,#Comparison_fastPLN_vs_fastPLNPCA_n=1000.png" alt="Running times for `\(n=1000, q=10, d=1\)`." width="60%" />
<p class="caption">Running times for `\(n=1000, q=10, d=1\)`.</p>
</div>

.pull-left2[
PLN

PLN-PCA
]
.pull-right2[
- convergence in a small number of iterations
- `\(\mathcal{O}(n p + p^2)\)` parameters to optimize + inversion of `\(\hat{\Sigma}\)` (`\(p\times p\)`)
- convergence requires a larger number of iterations
- `\(\mathcal{O}(n p + p q)\)` parameters to optimize + inversion of a `\(q\times q\)` matrix
]

---

# 2. Direct Gradient approx: .small[first ingredient]

Optimize a (`\(\mu\)`-strongly) convex function `\(f(\theta) = \sum_{i} f_i(\theta)\)` with `\(L\)`-Lipschitz gradient.
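To fix ideas, here is a minimal numpy sketch of the two updates presented next, on a least-squares instance of this setup (the data, step size `eta` and horizon `T` are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n, dim = 200, 5
A = rng.normal(size=(n, dim))
b = A @ rng.normal(size=dim) + 0.5 * rng.normal(size=n)   # noisy targets

def grad_fi(theta, i):
    # gradient of f_i(theta) = 0.5 * (a_i . theta - b_i)^2
    return (A[i] @ theta - b[i]) * A[i]

eta, T = 0.01, 10_000

# plain SGD: theta_{t+1} = theta_t - eta * grad f_{i_t}(theta_t)
theta = np.zeros(dim)
for _ in range(T):
    i = rng.integers(n)
    theta -= eta * grad_fi(theta, i)

# SAGA: correct the stochastic gradient with a table of stored gradients
theta_s = np.zeros(dim)
table = np.stack([grad_fi(theta_s, i) for i in range(n)])
avg = table.mean(axis=0)
for _ in range(T):
    i = rng.integers(n)
    g = grad_fi(theta_s, i)
    theta_s -= eta * (g - table[i] + avg)
    avg += (g - table[i]) / n            # keep the running average exact
    table[i] = g

theta_hat = np.linalg.lstsq(A, b, rcond=None)[0]   # exact minimizer
err_sgd = np.linalg.norm(theta - theta_hat)
err_saga = np.linalg.norm(theta_s - theta_hat)
print(err_sgd, err_saga)   # with a constant step, SAGA typically ends far closer
```

With a constant step, plain SGD stalls in a noise ball around the minimizer, while the variance-reduced update keeps contracting.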
#### Stochastic Gradient Descent (SGD)

`$$\begin{aligned} \text{For } & t=1,\dots,T \\ & i_t\sim\mathcal{U}([1,..,n]) \\ & \theta_{t+1} \leftarrow \theta_t - \eta \nabla f_{i_t}(\theta_t)\\ \end{aligned}$$`

--

#### SGD with Variance Reduction (SAGA, [DBL14])

`$$\begin{aligned} \text{For } & t=1,\dots,T \\ & i_t\sim\mathcal{U}([1,..,n])\\ & \theta_{t+1} \leftarrow \theta_t - \eta \left(\nabla f_{i_t}(\theta_t) - \nabla f_{i_t}(\alpha^t_{i_t}) + \frac{1}{n} \sum_{i=1}^n \nabla f_i (\alpha_i^t) \right) \\ & \text{For } i=1,\dots,n \\ & \quad \alpha_i^{t+1} \leftarrow \mathbf{1}_{\{i_t = i\}} \theta_t + \mathbf{1}_{\{i_t \neq i\}} \alpha_i^t \\ \end{aligned}$$`

- stabilizes the gradient estimate by averaging
- needs to store all the gradients
- other variants (SVRG, [Ge+19])

---

# 2. Direct Gradient approx: .small[second ingredient]

For these stochastic-gradient algorithms, we need to estimate `\(p_{\theta}(Y_i) = \mathbb{E}_{W} \left[p_\theta(Y_i|W) \right]\)`.

Let `\(\tilde p_{\theta}(W) = p_{\theta}(Y_i \mid W)\, p(W)\)` (explicit) and `\(n_s\)` the sampling effort. The MC approach uses

`$$p_{\theta}(Y_i) = \int \tilde p_{\theta}(W) \mathrm dW \approx \frac 1 {n_s} \sum_{k=1}^{n_s} \tilde p_{\theta}(V_k), \text{ where } (V_{k})_{1 \leq k \leq n_s} \overset{iid}{\sim} \mathcal{N}(0,I_q)$$`

`\(\rightsquigarrow\)` the prior samples fall where `\(\tilde p_{\theta}\)` has little mass: huge variance, poor approximation.

--

### Importance Sampling

Consider a density function `\(\phi\)` (the importance law). Importance sampling relies on

`$$\mathbb{E}_{W} \left[p_\theta(Y_i|W) \right] = \mathbb{E}_\phi \left[ \frac{p_\theta(Y_i|W) p(W)}{\phi(W)}\right],$$`

which yields the estimator

`$$p_{\theta}(Y_i) = \int \tilde p_{\theta}(W) \mathrm dW \approx \frac 1 {n_s} \sum_{k=1}^{n_s} \frac {\tilde p_{\theta}(V_k)}{\phi(V_k)}, \quad (V_{k})_{1 \leq k \leq n_s} \overset{iid}{\sim} \phi(.)$$`

---

# Choice of the importance law

### Optimal choice

The IS estimator is consistent, with variance depending on `\(\phi\)`.
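A 1-D toy run makes this dependence concrete: when `\(\phi \propto \tilde p_{\theta}\)`, the importance weights are exactly constant, i.e. zero variance (toy integrand below, not the PLN likelihood):

```python
import numpy as np

rng = np.random.default_rng(0)
ns = 10_000

# Toy target: estimate E_{W ~ N(0,1)}[f(W)] with f concentrated far from the
# prior; the exact value is exp(-4)/sqrt(2).
f = lambda w: np.exp(-0.5 * (w - 4.0) ** 2)
p = lambda w: np.exp(-0.5 * w**2) / np.sqrt(2 * np.pi)       # prior N(0,1)

# naive MC: prior samples rarely land where f is large
naive = f(rng.normal(size=ns))

# IS with phi proportional to f * p, here N(2, 1/2): constant weights
phi = lambda v: np.exp(-((v - 2.0) ** 2)) / np.sqrt(np.pi)   # N(2, 1/2) pdf
v = rng.normal(loc=2.0, scale=np.sqrt(0.5), size=ns)
weights = f(v) * p(v) / phi(v)

print(naive.std(), weights.std())   # IS weights have (numerically) zero spread
```

Both estimators are unbiased for the same integral; only the spread of the averaged terms changes.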
The variance-minimizing choice is
`$$\phi_{opt}(V) = \frac{\tilde{p}_{\theta}(V)}{p_\theta(Y_i)} = p(V|Y_i)$$`

--

### Practical choice

We choose `\(\phi = \mathcal{N}(\mu, \Sigma)\)` with
`$$\hat{\mu} = \arg\max_{W} \log p(W|Y_i) = \arg\max_{W} \log \tilde{p}_{\theta}(W),$$`
solved numerically (thank you, Pytorch). For the covariance, we take the inverse of minus the Hessian at the mode:
`$$\hat{\Sigma}^{-1} = -\frac{\partial^2}{\partial W^2} \log \tilde{p}_{\theta}(W)\Bigg\vert_{W=\hat{\mu}} = I_q + C^\top \operatorname{diag}(\exp(\beta^{\top} X_i + C \hat{\mu}))\, C$$`

---

# Gradient approximation

`$$\nabla _{\theta} \operatorname{log} p_{\theta}(Y_i) \approx \nabla_{\theta} \operatorname{log}\left(\frac 1 {n_s} \sum_{k=1}^{n_s} \frac {\tilde p_{\theta}(V_k)}{\phi(V_k)}\right)$$`

We have (with `\(O_i\)` a vector of known offsets)

`$$\tilde{p}_{\theta}(W_i) = \exp \left( - \frac 12 \| W_i\|^2 - \mathbf{1}_p^{\top} \exp(O_i + \beta^{\top}X_i + CW_i) + Y_i^{\top}(O_i + \beta^{\top}X_i +CW_i)\right),$$`

and derive the gradient formulas

`$$\nabla_{\beta} \log p_{\theta}(Y_i)\approx X_i Y_i^{\top} -\frac{\sum_{k = 1}^{n_s}\frac{\tilde p_{\theta}(V_k)}{\phi(V_k)}X_i\exp(O_i + \beta^{\top}X_i + C V_k)^{\top}}{\sum_{k = 1}^{n_s}\frac{\tilde p_{\theta}(V_k)}{\phi(V_k)}}$$`

`$$\nabla_{C} \log p_{\theta}(Y_i)\approx \frac{\sum_{k = 1}^{n_s}\frac{\tilde p_{\theta}(V_k)}{\phi(V_k)}\left[Y_{i}- \exp \left(O_i + \beta^{\top} X_{i}+C V_{k}\right)\right] V_{k}^{\top}}{\sum_{k = 1}^{n_s}\frac{\tilde p_{\theta}(V_k)}{\phi(V_k)}}$$`

Given the estimated gradients, we run a gradient ascent to increase the likelihood, using SGD with variance reduction (SAGA), mini-batches, etc.

---

# Performance of Importance Sampling (1)

Varying `\(p\)`

<div class="figure" style="text-align: center">
<img src="data:image/png;base64,#Convergence_analysis_IMPS_PLN_n=300,q=10,250_iterations.png" alt="Running times for `\(n=300, q=10, d=1\)`, 250 iterations."
width="80%" />
<p class="caption">Running times for `\(n=300, q=10, d=1\)`, 250 iterations.</p>
</div>

---

# Performance of Importance Sampling (2)

Varying `\(q\)`

<div class="figure" style="text-align: center">
<img src="data:image/png;base64,#Convergence_analysis_IMPS_PLN_n=300,p=2000,250_iterations.png" alt="Running times for `\(n=300, p=2000, d=1\)`, 250 iterations." width="80%" />
<p class="caption">Running times for `\(n=300, p=2000, d=1\)`, 250 iterations.</p>
</div>

---

# V-EM vs Importance Sampling

<br/>

Example with `\(n=p=1000\)`, `\(d=1, q=10\)`, Toeplitz (AR-like) covariance

<br/>

<img src="data:image/png;base64,#ELBOvslikelihood.png" width="110%" style="display: block; margin: auto;" />

<br/>

- orange: ELBO of the V-EM
- red: log-likelihood found by IMPS at convergence
- blue: log-likelihood computed with current V-EM estimates

---

# PLN Python package

### Installation

Clone the github repository

```bash
git clone https://github.com/pln-team/pyPLNmodels
```

Create an environment

```bash
conda create --name pyPLNmodels
conda activate pyPLNmodels
```

Install Pytorch (see [https://pytorch.org/](https://pytorch.org/))

```bash
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
```

Install the module `pyPLNmodels` with dependencies

```bash
pip install pyPLNmodels
```

---

# PLN Python package

### Load the package

Amazingly, this works from `RStudio` with **reticulate** correctly configured:

```r
library(reticulate)
use_condaenv("pyPLNmodels")
```

Then, in a Python chunk, you need

```python
from pyPLNmodels.models import PLNmodel, fastPLN, fastPLNPCA, IMPS_PLN
```

```
## Device cpu
```

### Load the data

In Python (data set with `\(p = 50, n = 200, d = 2\)`):

```python
import pandas as pd
Y = pd.read_csv("Y.csv")
X = pd.read_csv("X.csv")
O = pd.read_csv("O.csv")
```

---

# Fast PLN

```python
mypln = fastPLN()
mypln.fit(Y, O, X)
```

<img src="data:image/png;base64,#example_pln_criterion.png" width="60%" style="display: block;
margin: auto;" />

---

# Fast PLN-PCA

```python
myplnpca = fastPLNPCA(q=5)
myplnpca.fit(Y, O, X)
```

<img src="data:image/png;base64,#example_plnpca_criterion.png" width="60%" style="display: block; margin: auto;" />

---

# Importance Sampling for PLN-PCA

```python
imps = IMPS_PLN(q=5)
imps.fit(Y, O, X)
```

<img src="data:image/png;base64,#example_imps_criterion.png" width="60%" style="display: block; margin: auto;" />

---

# Thank you!

### In progress

- With Bastien/Joon: convergence study of SAGA/SVRG + Adagrad/Adam
- Functional version of PLN-PCA (spatio-temporal scRNA data); ZI-PLN (done)
- Statistical guarantees of V-EM estimates -> Mahendra

---

# References

Aitchison, J. and C. H. Ho (1989). "The multivariate Poisson-log normal distribution". In: _Biometrika_ 76.4, pp. 643-653.

Blei, D. M., A. Kucukelbir, and J. D. McAuliffe (2019). "Variational Inference: A Review for Statisticians". In: _JASA_.

Chiquet, J., M. Mariadassou, and S. Robin (2018). "Variational inference for probabilistic Poisson PCA". In: _Annals of Applied Stat._.

Chiquet, J., M. Mariadassou, and S. Robin (2021). "The Poisson-lognormal model as a versatile framework for the joint analysis of species abundances". In: _Frontiers in Ecology and Evolution_ 9, p. 188.

Defazio, A., F. Bach, and S. Lacoste-Julien (2014). "SAGA: A fast incremental gradient method with support for non-strongly convex composite objectives". In: _Advances in Neural Information Processing Systems_ 27.

Ge, R., Z. Li, W. Wang, et al. (2019). "Stabilized SVRG: Simple variance reduction for nonconvex optimization". In: _Conference on Learning Theory_. PMLR, pp. 1394-1448.

Paszke, A., S. Gross, S. Chintala, et al. (2017). "Automatic differentiation in PyTorch".

Riedmiller, M. A. and H. Braun (1993). "A direct adaptive method for faster backpropagation learning: the RPROP algorithm". In: _IEEE International Conference on Neural Networks_, pp. 586-591, vol. 1.
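
---

# Appendix: IS estimate of the likelihood in numpy

A self-contained sketch of the importance-sampling estimator of `\(\log p_{\theta}(Y_i)\)` with the Laplace-type proposal described earlier (mode found by plain gradient ascent rather than Pytorch; all dimensions and parameter values are illustrative, and constants independent of `\(\theta\)` are dropped):

```python
import numpy as np

rng = np.random.default_rng(0)
p_dim, q, d, ns = 10, 3, 2, 2_000

# toy parameters and one observation Y_i (all values are illustrative)
X_i = rng.normal(size=d)
beta = rng.normal(scale=0.2, size=(d, p_dim))
C = rng.normal(scale=0.2, size=(p_dim, q))
mean_term = beta.T @ X_i                      # beta^T X_i (offsets set to 0)
Y_i = rng.poisson(np.exp(mean_term + C @ rng.normal(size=q)))

def log_ptilde(W):
    # log p(Y_i | W) + log p(W), up to constants independent of theta
    lin = mean_term + C @ W
    return -0.5 * W @ W - np.exp(lin).sum() + Y_i @ lin

# mode of log_ptilde by gradient ascent (log_ptilde is concave in W)
mu = np.zeros(q)
for _ in range(200):
    mu += 0.05 * (-mu + C.T @ (Y_i - np.exp(mean_term + C @ mu)))

# Laplace covariance: inverse of minus the Hessian at the mode
H = -np.eye(q) - C.T @ (np.exp(mean_term + C @ mu)[:, None] * C)
Sigma = np.linalg.inv(-H)

# importance sampling with phi = N(mu, Sigma)
V = mu + rng.normal(size=(ns, q)) @ np.linalg.cholesky(Sigma).T
diff = V - mu
log_phi = (-0.5 * np.sum((diff @ np.linalg.inv(Sigma)) * diff, axis=1)
           - 0.5 * np.linalg.slogdet(2 * np.pi * Sigma)[1])
log_w = np.array([log_ptilde(v) for v in V]) - log_phi

# log p_theta(Y_i) ~ log of the mean of exp(log_w), computed stably
m = log_w.max()
log_lik = m + np.log(np.mean(np.exp(log_w - m)))
print(log_lik)
```

Increasing `ns` tightens the estimate; the proposal mean and covariance match the formulas of the "Choice of the importance law" slide.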